
Conversation

Contributor

@Aatman09 commented Jan 4, 2026

Resolves #107

Reference
This implementation is based on the following tutorial:
JAX Machine Translation Tutorial

Changes made

  • Added dataclass-based configuration for improved clarity and structure
  • Enhanced the tutorial with additional Markdown explanations for better readability

Notes

  • Key–value (KV) caching has been left out

Checklist

  • I have read the Contribution Guidelines and used pre-commit hooks to format this commit.
  • I have added all the necessary unit tests for my change. (run_model.py for model usage, test_outputs.py and/or model_validation_colab.ipynb for quality).
  • (If using an LLM) I have carefully reviewed and removed all superfluous comments or unneeded, commented-out code. Only necessary and functional code remains.
  • I have signed the Contributor License Agreement (CLA).

@chapman20j
Collaborator

Hi @Aatman09. Thank you for the nice commit. Could you please include a few pip installs at the beginning of the notebook for the additional dependencies, and include their versions, e.g. !pip install "grain==0.2.15". Also, please ensure that this notebook runs on Colab.

@chapman20j
Collaborator

For the KV cache, this would be nice to add in the Use Model For Inference section. Caching makes inference faster by allowing attention to reuse the previously computed k and v tensors. This gives you two options:

  1. Implement your own caching logic
  2. Change the flags for the attention layers

I think option 2 makes the most sense for this tutorial so it doesn't get too far into the weeds on the cache. Implementing your own caching may also require writing your own attention layers. For more details, the nnx docs cover how to initialize a cache (https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html). This can be done with the .init_cache or .set_mode methods. Please let me know if you'd like any further clarification or more discussion around this.
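
For illustration only, here is a minimal sketch of option 2 with a standalone flax.nnx attention layer (the layer sizes and names below are placeholders, not the ones from the notebook):

import jax.numpy as jnp
from flax import nnx

# Placeholder layer; the tutorial's decoder attention layers would be configured similarly.
attn = nnx.MultiHeadAttention(
    num_heads=8, in_features=256, decode=True, rngs=nnx.Rngs(0)
)

# Allocate the KV cache for shape (batch, max_decode_length, features).
attn.init_cache((1, 30, 256))

# In decode mode, each call consumes a single-token query and appends its
# key/value projections to the cache instead of recomputing them every step.
step_input = jnp.zeros((1, 1, 256))
out = attn(step_input)  # shape (1, 1, 256)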

Contributor Author

@Aatman09 commented Jan 8, 2026

Thank you for the review. I will implement the changes as soon as possible.

@Aatman09
Contributor Author

I reduced the number of epochs from 10 to 2 for testing, which is why the graph looks different.

@chapman20j
Collaborator

Hi @Aatman09. I see the issue. I'll make some quick general comments here.

  1. Could you use pre-commit as per the contribution guidelines? It makes it easier to review changes to notebooks.
  2. For the pip installs, could you update them to include "jax[tpu]==0.8.2" "flax==0.12.2"? This will make the versioning work properly when working on Google Colab (see the install cell sketched below).
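
For reference, a single install cell along these lines (versions taken from the suggestions in this thread) should cover the extra dependencies on Colab:

!pip install "jax[tpu]==0.8.2" "flax==0.12.2" "grain==0.2.15"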

@chapman20j
Collaborator

For the decode sequence function, it should look closer to this:

import jax.numpy as jnp
import numpy as np

# model, tokenizer, custom_standardization, tokenize_and_pad and
# sequence_length are defined earlier in the notebook.
def decode_sequence(input_sentence):
    input_sentence = custom_standardization(input_sentence)
    tokenized_input_sentence = tokenize_and_pad(input_sentence, tokenizer, sequence_length)
    encoder_input = jnp.array([tokenized_input_sentence])

    emb_enc = model.positional_embedding(encoder_input)
    encoder_outputs = model.encoder(emb_enc, mask=None)

    dummy_input_shape = (1, 30, model.config.embed_dim) # <- Update the cache size to be sufficiently large. I chose 30 here
    model.init_cache(dummy_input_shape)

    decoded_sentence = "[start]"
    current_token_id = tokenizer.encode("[start]")
    current_input = jnp.array([current_token_id])

    for i in range(sequence_length):
        logits = model.decode_step(current_input, encoder_outputs, step_index=i)

        sampled_id = np.argmax(logits[0, 0, :]).item()
        sampled_token = tokenizer.decode([sampled_id])

        decoded_sentence += "" + sampled_token # Your implementation had a space here, but this should be an empty string

        if sampled_token == "[end]":
            break

        # Update input for next loop
        current_input = jnp.array([[sampled_id]])

    return decoded_sentence

The main issue is that the KV cache is too small (only 1 token). With JAX, out-of-bounds indexing above the maximum index is clamped to the last index, resulting in a silent bug. Make sure that the KV cache is large enough to handle the decoding. I chose 30 here, but this should be decided programmatically (based on the input). English and Spanish are tokenized differently, so the exact number of output tokens isn't fully clear. Using a multiple of the number of input tokens should suffice (e.g. 2 or 3 times as many tokens for the cache).
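
For instance, a heuristic sketch that sizes the cache from the tokenized input (reusing the names from the snippet above; the factor 3 is just a generous guess):

# Size the cache as a multiple of the input length so decoding can never
# index past the end of the cached k/v tensors.
max_decode_length = 3 * len(tokenized_input_sentence)
model.init_cache((1, max_decode_length, model.config.embed_dim))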

Updating the decoded_sentence += ... line makes the output look nicer since words don't exactly correspond to tokens. Adding spaces will put spaces in-between parts of words.

Finally, you'll notice that these changes result in many ! tokens after the [end] token. These can be removed by truncating the string after the [end] token. A better solution is to stop the for loop when you encounter an [end] token.
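
If the loop is left to run for the full sequence_length instead of breaking, a simple post-hoc truncation (a one-line sketch using plain string splitting) would be:

decoded_sentence = decoded_sentence.split("[end]")[0]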

@chapman20j
Collaborator

Also, for the pip installs, using the jax[tpu] extra makes training fast on Google Colab. When I use a v5e-1 TPU, epoch 1 takes 42 seconds and epoch 2 takes 19 seconds. The first epoch is longer due to compilation time, but we can expect each epoch after the first to take about 19 seconds. Further code profiling and optimization can bring this number down, but those optimizations aren't necessary for this tutorial. However, enabling the TPU could allow you to run the notebook in about 4-5 minutes with 10 epochs.
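
As a quick sanity check that the TPU runtime is actually being used (standard JAX call, nothing specific to this notebook):

import jax
print(jax.devices())  # on a Colab TPU runtime this should list TPU devices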

@chapman20j merged commit 8948fcc into jax-ml:main Jan 15, 2026
3 checks passed
coder0143 pushed a commit to coder0143/bonsai that referenced this pull request Jan 19, 2026
* Refactor tutorial to use dataclass for configuration

* Implemented KV caching (WIP)

* final changes
Successfully merging this pull request may close these issues:

Port Encoder-Decoder Example from Jax AI Stack
